import os
import numpy as np
from PIL import Image
from scipy.io import loadmat

from oneflow.utils import data
import random


class StanfordDogs(data.Dataset):
    """Dataset for Stanford Dogs.

    Expected layout under ``root``::

        root/lists/<split>_list.mat     # read by __init__ ('file_list', 'labels')
        root/Images/<class_dir>/<file>  # read by __getitem__

    Each item is ``(image, label)`` where ``image`` is a PIL RGB image (or
    whatever ``transform`` returns) and ``label`` is a 0-based class index
    (or whatever ``target_transform`` returns).
    """
    urls = {"images.tar":       "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
            "annotation.tar":   "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
            "lists.tar":        "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"}

    def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None):
        """Load the file list and labels for one split.

        Args:
            root: dataset directory; ``~`` is expanded and the path is made absolute.
            split: list-file prefix, e.g. ``'train'`` -> ``lists/train_list.mat``.
            s: per-class keep fraction for ``sample_by_class``. NOTE: currently
                NOT applied; the split is always loaded in full. Kept for
                interface compatibility.
            download: if true, fetch and unpack the archives in ``urls`` first.
            transform: optional callable applied to the PIL image.
            target_transform: optional callable applied to the label.

        Raises:
            ValueError: if the list file's ``file_list`` and ``labels`` differ in length.
        """
        self.root = os.path.abspath(os.path.expanduser(root))
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()

        list_file = os.path.join(self.root, 'lists', self.split + '_list.mat')
        mat_file = loadmat(list_file)
        file_entries = mat_file['file_list']
        label_entries = mat_file['labels']
        if len(file_entries) != len(label_entries):
            raise ValueError('%s: file_list has %d entries but labels has %d'
                             % (list_file, len(file_entries), len(label_entries)))

        # loadmat returns nested arrays; [0][0] unwraps each entry to its relative image path.
        self.files = [str(entry[0][0]) for entry in file_entries]
        # Shift labels to 0-based class indices (the .mat labels are treated as 1-based).
        self.labels = np.array([entry[0] - 1 for entry in label_entries])

        images_dir = os.path.join(self.root, 'Images')
        # Only class directories count as categories; ignore stray files.
        categories = sorted(
            c for c in os.listdir(images_dir)
            if os.path.isdir(os.path.join(images_dir, c)))
        # Drop a fixed 10-character prefix from each folder name
        # (presumably a WordNet id such as 'n02085620-' -- confirm against the data).
        self.object_categories = [c[10:] for c in categories]
        print('Stanford Dogs, Split: %s, Size: %d' %
              (self.split, len(self)))

    def download(self):
        """Download and unpack the archives listed in ``urls`` into ``self.root``.

        Does nothing if ``root/Images`` and the list file for ``self.split``
        already exist. An archive already present in ``root`` is reused
        instead of being fetched again.
        """
        import urllib.request

        list_file = os.path.join(self.root, 'lists', self.split + '_list.mat')
        if os.path.isfile(list_file) and os.path.isdir(os.path.join(self.root, 'Images')):
            return

        os.makedirs(self.root, exist_ok=True)
        for fname, url in self.urls.items():
            archive = os.path.join(self.root, fname)
            if not os.path.isfile(archive):
                print('Downloading %s' % url)
                partial = archive + '.part'
                urllib.request.urlretrieve(url, partial)
                # Rename only after a complete download, so an interrupted
                # run is not later mistaken for a finished archive.
                os.replace(partial, archive)
            # __init__ reads root/lists/<split>_list.mat, so the list archive
            # is unpacked into root/lists; the others go straight into root
            # (__getitem__ reads root/Images/...).
            # NOTE(review): assumes lists.tar stores its .mat files at the
            # archive's top level -- confirm.
            dest = os.path.join(self.root, 'lists') if fname == 'lists.tar' else self.root
            os.makedirs(dest, exist_ok=True)
            print('Extracting %s' % archive)
            self._extract_tar(archive, dest)

    @staticmethod
    def _extract_tar(archive, dest):
        """Unpack ``archive`` into ``dest``, refusing members that would land outside it.

        Raises:
            RuntimeError: if a member path escapes ``dest`` (e.g. via ``..``).
        """
        import tarfile

        dest_real = os.path.realpath(dest)
        with tarfile.open(archive, 'r') as tar:
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(dest_real, member.name))
                if os.path.commonpath([dest_real, target]) != dest_real:
                    raise RuntimeError('Unsafe path in %s: %r' % (archive, member.name))
            if hasattr(tarfile, 'data_filter'):
                # Python 3.12+ (and security backports): let tarfile also
                # reject special files, absolute links, etc.
                tar.extractall(dest_real, filter='data')
            else:
                tar.extractall(dest_real)

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` after applying the optional transforms."""
        path = os.path.join(self.root, 'Images', self.files[idx])
        # The context manager closes the file handle; convert() returns a
        # new, fully loaded image that outlives the with-block.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl

    def sample_by_class(self, s):
        """Randomly keep ``int(n * s)`` files from each class in ``self.files``.

        The class is the first path component of each file (``'<class>/<file>'``).
        Uses the global ``random`` module, so seed it for reproducibility.

        NOTE: only file paths are returned; ``self.labels`` is not resampled.
        A caller that assigns the result to ``self.files`` must realign
        ``self.labels`` as well.
        """
        by_class = {}
        for path in self.files:
            by_class.setdefault(path.split('/')[0], []).append(path)

        sampled = []
        for paths in by_class.values():
            random.shuffle(paths)
            sampled.extend(paths[:int(len(paths) * s)])
        return sampled
